/* Copyright 2024 Tencent Inc.  All rights reserved.

==============================================================================*/
#pragma once

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <sstream>
#include <utility>

#include "ksana_llm/cache_manager/block_allocator/block_allocator_interface.h"
#include "ksana_llm/runtime/infer_request.h"
#include "ksana_llm/utils/device_types.h"
#include "ksana_llm/utils/memory_allocator_interface.h"
#include "test.h"

namespace ksana_llm {

// Simulated "current device" of the calling thread; written by FakedMemoryAllocator::SetDevice.
thread_local int g_cur_device_id;

// Simulated memory state shared by FakedBlockAllocator, FakedBlockAllocatorGroup and FakedMemoryAllocator.
struct FakedState {
  using BlockIdSet = std::unordered_set<int>;
  using KvContentMap = std::map<int, std::vector<int>>;

  // Block ids that are currently free: one set for the host, one set per device.
  BlockIdSet host_free_block_;
  std::vector<BlockIdSet> device_free_block_;

  // Block ids that are currently allocated: one set for the host, one set per device.
  BlockIdSet host_alloc_block_;
  std::vector<BlockIdSet> device_alloc_block_;

  // Simulated kv-cache contents standing in for what the LLM would write. The next token is derived from
  // these contents, so a buggy copy/swap shows up as a wrong generated sequence.
  // Indexed as device_kv_cache_contents_[device_id][block_idx][token_offset].
  std::vector<KvContentMap> device_kv_cache_contents_;

  // Host counterpart, indexed as host_kv_cache_contents_[block_idx][token_offset].
  KvContentMap host_kv_cache_contents_;

  // Guards the containers above. Recursive because helpers take it again while a caller already holds it.
  std::recursive_mutex mux_;

  // Value a token slot holds when nothing has been written to it.
  int DEFAULT_KV_CONTENT{-678};
  // Width of each block-id range; block ids are offset by a multiple of this per location.
  int DEVICE_ID_OFFSET{10000};
  // Range index reserved for host block ids.
  int HOST_DEVICE_ID{9};
};
static FakedState g_faked_state;

class FakedMemoryAllocator : public MemoryAllocatorInterface {
 public:
  virtual ~FakedMemoryAllocator() {}

  virtual void SetDevice(int device_id) override {
    KLLM_LOG_DEBUG << "SetDeviceId from " << g_cur_device_id << " to " << device_id;
    g_cur_device_id = device_id;
  }

  virtual void GetDevice(int* device_id) override { *device_id = g_cur_device_id; }

  virtual void Malloc(void** dev_ptr, size_t size) override {}
  virtual void MallocAsync(void** dev_ptr, size_t size, Stream stream) override {}

  virtual void MemsetAsync(void* dev_ptr, int value, size_t count, Stream stream) override {}
  virtual void Memset(void* dev_ptr, int value, size_t count) override {}

  virtual void MemcpyAsync(void* dst, const void* src, size_t count, enum MemcpyKind kind, Stream stream) override {}
  virtual void Memcpy(void* dst, const void* src, size_t count, enum MemcpyKind kind) override {}

  virtual void Free(void* dev_ptr) override {}
  virtual void FreeAsync(void* dev_ptr, Stream stream) override {}

  virtual void HostAlloc(void** host_ptr, size_t size) override {}

  virtual void HostFree(void* host_ptr) override {}
};

// BlockAllocatorInterface implementation backed by g_faked_state instead of real memory.
// "Allocating" a block only moves its integer id from a free set to a used set. A device allocator works on
// g_faked_state.device_{free,alloc}_block_[rank]; a host allocator works on g_faked_state.host_{free,alloc}_block_.
// No memory pointers are ever produced.
class FakedBlockAllocator : public BlockAllocatorInterface {
 public:
  // @param location         Host or device pool managed by this allocator.
  // @param rank             Device index of the pool (only used for LOCATION_DEVICE).
  // @param memory_allocator Used to switch the calling thread's current device before device operations;
  //                         must be non-null for device allocators.
  FakedBlockAllocator(MemoryLocation location, int rank = 0,
                      std::shared_ptr<MemoryAllocatorInterface> memory_allocator = nullptr)
      : location_(location), rank_(rank), memory_allocator_(std::move(memory_allocator)) {}

  virtual ~FakedBlockAllocator() {}

  virtual void PreAllocateBlocks() override {}

  virtual void Clear() override {}

  // Moves `block_num` ids from the free set to the used set and returns them in `blocks`.
  // Returns RET_DEVICE_MEM_ALLOCATE_FAILED without touching the sets if fewer than `block_num` are free.
  virtual Status AllocateBlocks(size_t block_num, std::vector<int>& blocks) override {
    if (location_ == MemoryLocation::LOCATION_DEVICE) {
      memory_allocator_->SetDevice(rank_);
      // Index by this allocator's own rank rather than the thread-local current device, so the pool used does
      // not depend on what the memory allocator's SetDevice does.
      return AllocBlocks(rank_, block_num, blocks, g_faked_state.device_free_block_[rank_],
                         g_faked_state.device_alloc_block_[rank_]);
    }
    std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
    return AllocBlocks(g_faked_state.HOST_DEVICE_ID, block_num, blocks, g_faked_state.host_free_block_,
                       g_faked_state.host_alloc_block_);
  }

  // Moves every id in `blocks` from the used set back to the free set.
  // Freeing an id that is not currently allocated fails the KLLM check.
  virtual Status FreeBlocks(const std::vector<int>& blocks) override {
    if (location_ == MemoryLocation::LOCATION_DEVICE) {
      memory_allocator_->SetDevice(rank_);
      return FreeBlocks(rank_, blocks, g_faked_state.device_free_block_[rank_],
                        g_faked_state.device_alloc_block_[rank_]);
    }
    return FreeBlocks(g_faked_state.HOST_DEVICE_ID, blocks, g_faked_state.host_free_block_,
                      g_faked_state.host_alloc_block_);
  }

  virtual Status GetBlockPtrs(const std::vector<int>& blocks, std::vector<void*>& addrs) override {
    KLLM_CHECK_WITH_INFO(false, "GetBlockPtrs not implemented");
    return Status();
  }

  virtual void* GetBlocksBasePtr() override { return nullptr; }

  virtual int GetBlocksBaseId() override { return 0; }

  // Number of free blocks in this allocator's own pool (device `rank_` or host).
  virtual size_t GetFreeBlockNumber() override {
    std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
    if (location_ == MemoryLocation::LOCATION_DEVICE) {
      return g_faked_state.device_free_block_[rank_].size();
    }
    return g_faked_state.host_free_block_.size();
  }

  // Number of allocated blocks in this allocator's own pool (device `rank_` or host).
  virtual size_t GetUsedBlockNumber() override {
    std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
    if (location_ == MemoryLocation::LOCATION_DEVICE) {
      return g_faked_state.device_alloc_block_[rank_].size();
    }
    return g_faked_state.host_alloc_block_.size();
  }

 private:
  // Renders block ids as "a, b, c, " for debug logging.
  static std::string FormatBlockIds(const std::vector<int>& blocks) {
    std::stringstream ss_blocks;
    for (int block_id : blocks) {
      ss_blocks << block_id << ", ";
    }
    return ss_blocks.str();
  }

  // Takes `block_num` arbitrary ids out of `free_blocks`, records them in `used_blocks` and returns them in
  // `blocks`. `device_id` is only used for logging.
  Status AllocBlocks(int device_id, size_t block_num, std::vector<int>& blocks, std::unordered_set<int>& free_blocks,
                     std::unordered_set<int>& used_blocks) {
    std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
    if (block_num > free_blocks.size()) {
      KLLM_LOG_DEBUG << "Failed to alloc on device " << device_id << ", block_num=" << block_num
                     << ", free_blocks.size()=" << free_blocks.size();
      // Both values are size_t, so %zu is required; %d would be undefined behavior.
      return Status(RET_DEVICE_MEM_ALLOCATE_FAILED,
                    FormatStr("No more free blocks, expect %zu, free %zu", block_num, free_blocks.size()));
    }

    blocks.clear();
    blocks.reserve(block_num);
    auto it = free_blocks.begin();
    for (size_t i = 0; i < block_num; ++i) {
      used_blocks.insert(*it);
      blocks.push_back(*it);
      it = free_blocks.erase(it);
    }
    KLLM_LOG_DEBUG << "Alloc OK on device " << device_id << ", alloc block num=" << blocks.size()
                   << ", free_blocks.size()=" << free_blocks.size() << ", used_blocks.size()=" << used_blocks.size()
                   << ", blocks=[" << FormatBlockIds(blocks) << "].";
    return Status();
  }

  // Returns every id in `blocks` from `used_blocks` to `free_blocks`. `device_id` is only used for logging.
  Status FreeBlocks(int device_id, const std::vector<int>& blocks, std::unordered_set<int>& free_blocks,
                    std::unordered_set<int>& used_blocks) {
    std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
    KLLM_LOG_DEBUG << "Free start on device " << device_id << ", block num=" << blocks.size()
                   << ", free_blocks.size()=" << free_blocks.size() << ", used_blocks.size()=" << used_blocks.size()
                   << ", blocks=[" << FormatBlockIds(blocks) << "].";

    for (int block_id : blocks) {
      auto it = used_blocks.find(block_id);
      if (it == used_blocks.end()) {
        // The id is not currently allocated (double free or foreign id). Report it through the KLLM check
        // rather than a bare assert, so the failure carries a message in every build mode.
        KLLM_CHECK_WITH_INFO(false, "Double free");
        return Status(RET_DEVICE_MEM_FREE_FAILED, fmt::format("Double free error, block id {}", block_id));
      }
      free_blocks.insert(*it);
      used_blocks.erase(it);
    }
    KLLM_LOG_DEBUG << "Free OK on device " << device_id;
    return Status();
  }

 private:
  MemoryLocation location_ = MemoryLocation::LOCATION_HOST;
  int rank_ = -1;
  std::shared_ptr<MemoryAllocatorInterface> memory_allocator_ = nullptr;
};

// BlockAllocatorGroupInterface simulator for scheduler tests.
// Creates one host FakedBlockAllocator plus one device FakedBlockAllocator per device. Seeds g_faked_state with the
// block ids and simulated kv-cache contents for all of them.
// Block ids encode their location: host block i is HOST_DEVICE_ID * DEVICE_ID_OFFSET + i, and block i of device d
// is (d + 1) * DEVICE_ID_OFFSET + i. Each block carries block_token_num ints of simulated kv-cache content, so
// swap/copy bugs in the code under test surface as wrongly generated tokens.
class FakedBlockAllocatorGroup : public BlockAllocatorGroupInterface {
 public:
  // @param block_manager_config Source of block_token_num and of the host/device block counts.
  // @param tp_num               Number of simulated devices.
  FakedBlockAllocatorGroup(const BlockManagerConfig& block_manager_config, int tp_num) {
    block_manager_config_ = block_manager_config;

    block_token_num_ = block_manager_config_.device_allocator_config.block_token_num;
    host_block_num_ = block_manager_config_.host_allocator_config.blocks_num;
    device_block_num_ = block_manager_config_.device_allocator_config.blocks_num;

    g_cur_device_id = 0;
    device_num_ = tp_num;
    KLLM_LOG_INFO << "device_num=" << device_num_ << ", block_token_num=" << block_token_num_
                  << ", host_block_num=" << host_block_num_ << ", device_block_num=" << device_block_num_;
    KLLM_CHECK_WITH_INFO(
        (host_block_num_ <= g_faked_state.DEVICE_ID_OFFSET) && (device_block_num_ < g_faked_state.DEVICE_ID_OFFSET),
        FormatStr("block_num should be less than DEVICE_ID_OFFSET=%d", g_faked_state.DEVICE_ID_OFFSET));
    // Device d owns ids starting at (d + 1) * DEVICE_ID_OFFSET. With device_num == HOST_DEVICE_ID, the last
    // device's ids would coincide with the host ids, so the bound must be strict.
    KLLM_CHECK_WITH_INFO((device_num_ < g_faked_state.HOST_DEVICE_ID),
                         FormatStr("device_num should be less than HOST_DEVICE_ID=%d", g_faked_state.HOST_DEVICE_ID));

    memory_allocator_ = std::make_shared<FakedMemoryAllocator>();
    host_allocator_ = std::make_shared<FakedBlockAllocator>(MemoryLocation::LOCATION_HOST, 0, memory_allocator_);
    for (int dev_id = 0; dev_id < device_num_; ++dev_id) {
      dev_allocators_[dev_id] =
          std::make_shared<FakedBlockAllocator>(MemoryLocation::LOCATION_DEVICE, dev_id, memory_allocator_);
    }

    std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
    // Content of a block nobody has written to yet.
    const std::vector<int> empty_block_content(std::max(block_token_num_, 0), g_faked_state.DEFAULT_KV_CONTENT);

    // Init free host block ids and their kv cache contents.
    for (int i = 0; i < host_block_num_; i++) {
      const int block_id = g_faked_state.HOST_DEVICE_ID * g_faked_state.DEVICE_ID_OFFSET + i;
      g_faked_state.host_kv_cache_contents_[block_id] = empty_block_content;
      g_faked_state.host_free_block_.insert(block_id);
    }

    // g_faked_state outlives this object and device d is always addressed as index d. Only grow the per-device
    // containers when they are too small, instead of appending never-used entries on every construction.
    const size_t device_num = static_cast<size_t>(std::max(device_num_, 0));
    if (g_faked_state.device_kv_cache_contents_.size() < device_num) {
      g_faked_state.device_kv_cache_contents_.resize(device_num);
    }
    if (g_faked_state.device_free_block_.size() < device_num) {
      g_faked_state.device_free_block_.resize(device_num);
    }
    if (g_faked_state.device_alloc_block_.size() < device_num) {
      g_faked_state.device_alloc_block_.resize(device_num);
    }

    // Init free device block ids and their kv cache contents.
    for (int d_i = 0; d_i < device_num_; d_i++) {
      for (int b_i = 0; b_i < device_block_num_; b_i++) {
        const int block_id = g_faked_state.DEVICE_ID_OFFSET * (d_i + 1) + b_i;
        g_faked_state.device_kv_cache_contents_[d_i][block_id] = empty_block_content;
        g_faked_state.device_free_block_[d_i].insert(block_id);
      }
    }

    KLLM_LOG_DEBUG << "BlockManagerSimulator started";
  }

  virtual ~FakedBlockAllocatorGroup() {}

  virtual std::shared_ptr<BlockAllocatorInterface> GetHostBlockAllocator() const override { return host_allocator_; }

  virtual std::vector<int> GetBlockAllocatorDevices() const override {
    std::vector<int> devices;
    for (int dev_id = 0; dev_id < device_num_; ++dev_id) {
      devices.push_back(dev_id);
    }
    return devices;
  }

  virtual std::shared_ptr<BlockAllocatorInterface> GetDeviceBlockAllocator(int device_id = 0) const override {
    return dev_allocators_.at(device_id);
  }

  // Copies the simulated kv-cache contents of `host_block_id` into `device_block_id` on `device_id`, then
  // resets the host block to DEFAULT_KV_CONTENT.
  virtual Status SwapIn(int device_id, int device_block_id, int host_block_id) override {
    KLLM_LOG_DEBUG << "SwapIn on device " << device_id << ", host_block_id=" << host_block_id
                   << ", device_block_id=" << device_block_id;

    std::vector<int> device_blocks{device_block_id};
    std::vector<int> host_blocks{host_block_id};

    {
      // Copy and reset under one lock so no reader observes a half-swapped block.
      std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
      CopyKvCacheContents(host_blocks, device_blocks, g_faked_state.host_kv_cache_contents_,
                          g_faked_state.device_kv_cache_contents_[device_id]);
      ResetKvCacheContents(host_blocks, g_faked_state.host_kv_cache_contents_);
    }
    // TODO(robertyuan): Simulate communication delay
    std::this_thread::sleep_for(std::chrono::microseconds(1));

    KLLM_LOG_DEBUG << "SwapIn finished on device " << device_id << ". host_block_id=" << host_block_id
                   << ", device_block_id=" << device_block_id;
    {
      // Swaps may run on several threads; guard the counter.
      std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
      stat_.swapin_succ_num += static_cast<int>(host_blocks.size());
    }
    return Status();
  }

  // Copies the simulated kv-cache contents of `device_block_id` on `device_id` into `host_block_id`, then
  // resets the device block to DEFAULT_KV_CONTENT.
  virtual Status SwapOut(int device_id, int host_block_id, int device_block_id) override {
    KLLM_LOG_DEBUG << "SwapOut on device " << device_id << " device_blockid=" << device_block_id
                   << ", host_block_id=" << host_block_id;

    std::vector<int> device_blocks{device_block_id};
    std::vector<int> host_blocks{host_block_id};

    {
      // Copy and reset under one lock so no reader observes a half-swapped block.
      std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
      CopyKvCacheContents(device_blocks, host_blocks, g_faked_state.device_kv_cache_contents_[device_id],
                          g_faked_state.host_kv_cache_contents_);
      ResetKvCacheContents(device_blocks, g_faked_state.device_kv_cache_contents_[device_id]);
    }
    // TODO(robertyuan): Simulate communication delay
    std::this_thread::sleep_for(std::chrono::microseconds(1));

    {
      // Swaps may run on several threads; guard the counter.
      std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
      stat_.swapout_succ_num += static_cast<int>(device_blocks.size());
    }
    return Status();
  }

  const BlockManagerConfig& GetBlockManagerConfig() const { return block_manager_config_; }

 public:
  // Functions not in BlockManagerInferface

  // Reads the simulated kv-cache contents of the first output_tokens.size() tokens of `req` from device 0 into
  // `kv_cache_contents`. Checks that every other device d holds the same contents shifted by +d (as written by
  // RecordGeneratedToken).
  void CollectKvCacheContent(std::shared_ptr<InferRequest>& req, std::vector<int>& kv_cache_contents) {
    std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
    int kv_cache_token_num = req->output_tokens.size();
    // Collect kv cache content from device 0
    CollectKvCacheContentFromDevice(0, req->kv_cache_blocks[0], kv_cache_token_num, kv_cache_contents);

    // Check all devices have the same contents (up to the per-device shift).
    for (int d_i = 1; d_i < device_num_; d_i++) {
      std::vector<int> temp_contents;
      CollectKvCacheContentFromDevice(d_i, req->kv_cache_blocks[d_i], kv_cache_token_num, temp_contents);
      for (int i = 0; i < kv_cache_token_num; i++) {
        KLLM_CHECK_WITH_INFO(
            kv_cache_contents[i] == temp_contents[i] - d_i,
            FormatStr("Kv cache content diff between device 0 and device %d, token_idx=%d.", d_i, i));
      }
    }
  }

  // Writes `output_token + d` into the kv-cache slot of token position `offset` on every device d.
  // The request must already own enough blocks; otherwise the scheduler has a bug and the check fails.
  void RecordGeneratedToken(std::shared_ptr<InferRequest>& req, int offset, int output_token) {
    std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
    int block_offset = offset / block_token_num_;
    int offset_in_block = offset % block_token_num_;
    for (int d_i = 0; d_i < device_num_; d_i++) {
      std::vector<int>& kv_cache_blocks = req->kv_cache_blocks[d_i];
      KLLM_CHECK_WITH_INFO(
          kv_cache_blocks.size() > static_cast<size_t>(block_offset),
          FormatStr("Block not exist. Req id %lld, block offset=%d, device_idx=%d, kv_cache_blocks.size()=%zu.",
                    static_cast<long long>(req->req_id), block_offset, d_i, kv_cache_blocks.size()));
      int block_idx = kv_cache_blocks[block_offset];
      auto& device_contents = g_faked_state.device_kv_cache_contents_[d_i];
      auto content_it = device_contents.find(block_idx);
      KLLM_CHECK_WITH_INFO(content_it != device_contents.end(),
                           FormatStr("Block kv cache content not exist on device %d. Req id %lld, block idx=%d.", d_i,
                                     static_cast<long long>(req->req_id), block_idx));
      content_it->second[offset_in_block] = output_token + d_i;
    }
  }

  struct Statistics {
    int swapout_succ_num = 0;
    int swapout_fail_num = 0;
    int swapin_succ_num = 0;
    int swapin_fail_num = 0;
  };

  const Statistics& GetStatistics() { return stat_; }

 private:
  // Reads the kv-cache contents of the first `token_num` tokens stored in `block_list` on `device_idx`.
  void CollectKvCacheContentFromDevice(int device_idx, const std::vector<int>& block_list, int token_num,
                                       std::vector<int>& kv_cache_contents) {
    std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
    auto& device_contents = g_faked_state.device_kv_cache_contents_[device_idx];
    kv_cache_contents.resize(token_num);
    for (int i = 0; i < token_num; i++) {
      int block_offset = i / block_token_num_;
      int offset_in_block = i % block_token_num_;
      KLLM_CHECK_WITH_INFO(static_cast<size_t>(block_offset) < block_list.size(),
                           FormatStr("block list on device %d is broken. size=%zu, visiting block idx=%d.", device_idx,
                                     block_list.size(), block_offset));
      int block_idx = block_list[block_offset];
      // find() instead of operator[]: an unknown block must not be inserted as an empty vector and then read out
      // of bounds.
      auto content_it = device_contents.find(block_idx);
      KLLM_CHECK_WITH_INFO(content_it != device_contents.end(),
                           FormatStr("Block kv cache content not exist on device %d. block idx=%d.", device_idx,
                                     block_idx));
      kv_cache_contents[i] = content_it->second[offset_in_block];
    }
  }

  // Copies the full contents of src_blks[i] into dst_blks[i] for every i < src_blks.size().
  void CopyKvCacheContents(const std::vector<int>& src_blks, const std::vector<int>& dst_blks,
                           std::map<int, std::vector<int>>& src_kv_contents,
                           std::map<int, std::vector<int>>& dst_kv_contents) {
    std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
    KLLM_CHECK_WITH_INFO(src_blks.size() <= dst_blks.size(),
                         FormatStr("src_blks.size > dst_blks.size(), %zu, %zu", src_blks.size(), dst_blks.size()));
    for (size_t i = 0; i < src_blks.size(); i++) {
      int src_blk = src_blks[i];
      int dst_blk = dst_blks[i];
      auto src_content_it = src_kv_contents.find(src_blk);
      auto dst_content_it = dst_kv_contents.find(dst_blk);
      KLLM_CHECK_WITH_INFO(src_content_it != src_kv_contents.end(),
                           FormatStr("Kv cache content of src block %d does not exist", src_blk));
      KLLM_CHECK_WITH_INFO(dst_content_it != dst_kv_contents.end(),
                           FormatStr("Kv cache content of dst block %d does not exist", dst_blk));
      std::ostringstream ss;
      ss << "CopyKvCacheContents from src_block_id " << src_blk << " to dst_block_id " << dst_blk << ". src tokens: ";
      for (int j = 0; j < block_token_num_; j++) {
        dst_content_it->second[j] = src_content_it->second[j];
        ss << dst_content_it->second[j] << ",";
      }
      KLLM_LOG_DEBUG << ss.str();
    }
  }

  // Sets every token slot of the given blocks back to DEFAULT_KV_CONTENT.
  void ResetKvCacheContents(const std::vector<int>& blks, std::map<int, std::vector<int>>& kv_contents) {
    std::lock_guard<std::recursive_mutex> guard(g_faked_state.mux_);
    for (auto blk : blks) {
      auto content_it = kv_contents.find(blk);
      KLLM_CHECK_WITH_INFO(content_it != kv_contents.end(),
                           FormatStr("Kv cache content of block %d does not exist", blk));
      for (int i = 0; i < block_token_num_; i++) {
        content_it->second[i] = g_faked_state.DEFAULT_KV_CONTENT;
      }
    }
  }

  BlockManagerConfig block_manager_config_;

  int block_token_num_;
  int host_block_num_;
  int device_block_num_;
  int device_num_;

  WorkspaceMeta dummy_workspace_meta_;

  Statistics stat_;

  std::shared_ptr<MemoryAllocatorInterface> memory_allocator_ = nullptr;

  std::shared_ptr<BlockAllocatorInterface> host_allocator_ = nullptr;
  std::unordered_map<int, std::shared_ptr<BlockAllocatorInterface>> dev_allocators_;
};

// Deterministic stand-in for a model forward pass. Derives the next token from the sum of all previous tokens
// and the given seed, and stores it in `output_token` (the value may be negative).
// The sum and the multiplication use 32-bit unsigned arithmetic, which wraps modulo 2^32 by definition. The
// previous signed-int version overflowed (undefined behavior) for long or large-valued sequences. The wrapped
// value is reinterpreted as a signed int, which matches the two's-complement results the old code produced in
// practice.
inline void GenerateAFakeToken(std::vector<int>& input_tokens, int& output_token, int seed) {
  const uint32_t sum_of_elems = std::accumulate(input_tokens.begin(), input_tokens.end(), uint32_t{0});
  const int mixed = static_cast<int>(sum_of_elems * 0x3e1f9d7u);
  output_token = ((mixed % 0x19de1f3) + seed) % 200000;
}

// Returns the seed active at token position `offset`.
// `seeds` holds (start_offset, seed) pairs and is scanned in order (expected to be sorted by start_offset).
// The scan stops at the first entry that starts after `offset`, and the last entry seen before that wins.
// Returns -1 when no entry starts at or before `offset`.
inline int GetFakeSeed(int offset, const std::vector<std::pair<int, int>>& seeds) {
  KLLM_CHECK_WITH_INFO(offset >= 0, "");
  int active_seed = -1;
  for (const auto& range_start_and_seed : seeds) {
    if (offset < range_start_and_seed.first) {
      break;
    }
    active_seed = range_start_and_seed.second;
  }
  return active_seed;
}

// Token id used as the fake end-of-sequence marker.
inline int GetFakeEndId() {
  constexpr int kFakeEndId = -1;
  return kFakeEndId;
}

// Writes `input_tokens` followed by `output_token_num` deterministic fake tokens into `output_tokens`.
// Each new token is derived from every token before it and from the seed active at its position in `seed_list`.
// When `with_eos` is true, the last of the `output_token_num` tokens is the fake end id instead of a generated token.
inline void GenerateFakeTokens(std::vector<int>& input_tokens, int output_token_num, std::vector<int>& output_tokens,
                               bool with_eos, const std::vector<std::pair<int, int>>& seed_list) {
  output_tokens.clear();
  if (!input_tokens.empty()) {
    output_tokens.insert(output_tokens.end(), input_tokens.begin(), input_tokens.end());
  }

  // Reserve the final slot for the end id when requested.
  const int generated_num = with_eos ? output_token_num - 1 : output_token_num;
  for (int step = 0; step < generated_num; ++step) {
    int next_token = 0;
    GenerateAFakeToken(output_tokens, next_token, GetFakeSeed(input_tokens.size() + step, seed_list));
    output_tokens.push_back(next_token);
  }

  if (with_eos) {
    output_tokens.push_back(GetFakeEndId());
  }
}

inline std::vector<std::shared_ptr<InferRequest>> InitFakeRequest(int req_id, int input_token_num,
                                                                  int expected_output_token_num,
                                                                  std::shared_ptr<Request>& req,
                                                                  const std::vector<std::pair<int, int>>& seeds,
                                                                  size_t tp_num) {
  KLLM_LOG_DEBUG << "Init req " << req_id << ", input_token_num=" << input_token_num
                 << ", expect_output_token_num=" << expected_output_token_num;
  std::shared_ptr<KsanaPythonInput> ksana_python_input = std::make_shared<KsanaPythonInput>();
  ksana_python_input->sampling_config.num_beams = 0;
  ksana_python_input->sampling_config.num_return_sequences = 1;
  auto req_ctx = std::make_shared<std::unordered_map<std::string, std::string>>(
      std::unordered_map<std::string, std::string>{{"key1", "value1"}, {"key2", "value2"}});
  req = std::make_shared<Request>(ksana_python_input, req_ctx);
  req->req_id = req_id;
  req->model_name = "llama";
  req->waiter = std::make_shared<Waiter>(1);
  std::vector<int> dummy_tokens;
  GenerateFakeTokens(dummy_tokens, input_token_num, req->input_tokens, false, seeds);
  req->output_tokens = req->input_tokens;

  std::vector<std::shared_ptr<InferRequest>> infer_req_list;
  for (size_t i = 0; i < req->output_group.size(); i++) {
    std::shared_ptr<InferRequest> infer_req = std::make_shared<InferRequest>(req, i);
    infer_req->sampling_config.stop_token_ids.push_back(GetFakeEndId());
    infer_req->kv_cache_blocks.resize(tp_num);
    infer_req_list.push_back(infer_req);
  }
  return infer_req_list;
}

class BatchSchedulerEnvironmentSimulator {
 public:
  BatchSchedulerEnvironmentSimulator(const BlockManagerConfig& block_manager_config, int tp_num,
                                     std::shared_ptr<FakedBlockAllocatorGroup> block_allocator_group)
      : tp_num_(tp_num) {
    block_allocator_group_ = block_allocator_group;
    block_manager_config_ = block_manager_config;
    ProfilerConfig profiler_config;
  }
  ~BatchSchedulerEnvironmentSimulator() {}

  void RunAStep(std::vector<std::shared_ptr<InferRequest>>& scheduled_reqs) {
    for (auto req : scheduled_reqs) {
      // Note: Not for chunked prefill or speculative decoding
      KLLM_CHECK_WITH_INFO(req->output_tokens.size() == req->forwarding_tokens.size(),
                           FormatStr("output_tokens.size=%d is not equal to forwarding_tokens.size=%d.",
                                     req->output_tokens.size(), req->forwarding_tokens.size()));
      // Generate kv cache content
      // This operation should be done before generation because kv cache is
      // writen during generating next token
      KLLM_LOG_DEBUG << "RunAStep:" << req;
      // Generate kv cache for not cached tokens
      for (size_t i = req->kv_cached_token_num; i < req->output_tokens.size(); ++i) {
        block_allocator_group_->RecordGeneratedToken(req, i, req->output_tokens[i]);
      }

      // Generate a token
      int output_token = GetEndId();
      KLLM_CHECK_WITH_INFO(req_output_num_map_.find(req->req_id) != req_output_num_map_.end(),
                           FormatStr("Req id %d is not exist in req_output_num_map.", req->req_id));
      if ((req->output_tokens.size() - req->input_tokens.size()) < (size_t)(req_output_num_map_[req->req_id] - 1)) {
        std::vector<int> kv_contents;
        // Generate next token based on recorded kv cache content
        // If memory operations break kv cache content, generation results will be wrong
        block_allocator_group_->CollectKvCacheContent(req, kv_contents);
        GenerateAToken(kv_contents, output_token,
                       GetSeed(req->output_tokens.size(), req_generation_seeds_[req->req_id]));

        std::ostringstream ss;
        ss << "GenerateToken " << *req << ", new_generate_token:" << output_token
           << ", kv_contents size: " << kv_contents.size() << ", kv not equal {";
        for (size_t i = 0; i < kv_contents.size(); i++) {
          EXPECT_EQ(req->forwarding_tokens[i], kv_contents[i]);
          if (req->forwarding_tokens[i] != kv_contents[i]) {
            int block_offset = i / block_allocator_group_->block_token_num_;
            int offset_in_block = i % block_allocator_group_->block_token_num_;
            auto& block_list = req->kv_cache_blocks[0];
            int block_idx = block_list[block_offset];
            ss << "kv_cache_not_equal token idx:" << i << ": tokens(" << req->output_tokens[i] << "), fwd("
               << req->forwarding_tokens[i] << ") vs cache(" << kv_contents[i] << "), block_offset:" << block_offset
               << ", offset_in_block:" << offset_in_block << ", block_idx:" << block_idx << ", ";
            std::cout << ss.str() << std::endl;
          }
        }
        ss << "} ";
        KLLM_LOG_DEBUG << ss.str();
      }
      req->sampling_result_tokens.clear();
      req->sampling_result_tokens.emplace_back(output_token);
      req->generated_token = output_token;
    }
    // Assumption: A step is slower than swapout
    std::this_thread::sleep_for(std::chrono::microseconds(2));
  }

  bool IsRequestFinished(std::shared_ptr<InferRequest>& req) { return req->output_tokens.back() == GetEndId(); }

  int GetEndId() { return GetFakeEndId(); }

  std::vector<std::shared_ptr<InferRequest>> InitRequest(int req_id, int input_token_num, int expected_output_token_num,
                                                         std::shared_ptr<Request>& req,
                                                         const std::vector<std::pair<int, int>>& seeds) {
    std::vector<std::shared_ptr<InferRequest>> infer_req_list =
        InitFakeRequest(req_id, input_token_num, expected_output_token_num, req, seeds, tp_num_);
    for (size_t i = 0; i < req->output_group.size(); i++) {
      std::shared_ptr<InferRequest> infer_req = infer_req_list[i];
      SetRequestOutputTokenNum(infer_req, expected_output_token_num);
      SetRequestGenerationSeeds(infer_req->req_id, seeds);
    }
    return infer_req_list;
  }

  void CheckRequestOutput(const std::shared_ptr<InferRequest>& req) {
    // Check request results
    int expected_generate_output_token_num = req_output_num_map_[req->req_id];
    KLLM_LOG_DEBUG << "Checking " << *req << ", input_tokens.size=" << req->input_tokens.size()
                   << ", output_tokens.size=" << req->output_tokens.size()
                   << ", expected_generate_output_token_num=" << expected_generate_output_token_num;
    std::vector<int> expect_output_tokens;
    bool with_eos = true;

    GenerateFakeTokens(req->input_tokens, expected_generate_output_token_num, expect_output_tokens, with_eos,
                       req_generation_seeds_[req->req_id]);
    if (req->finish_status.OK()) {
      EXPECT_EQ(expect_output_tokens.size(), req->output_tokens.size());
      if (expect_output_tokens.size() != req->output_tokens.size()) {
        KLLM_LOG_ERROR << "check size fail " << req;
        std::cerr << "check size fail " << *req << std::endl;
      }
      EXPECT_EQ(expect_output_tokens.size(), req->input_tokens.size() + expected_generate_output_token_num);
      for (size_t i = 0; i < req->output_tokens.size(); i++) {
        EXPECT_EQ(expect_output_tokens[i], req->output_tokens[i]);
        if (expect_output_tokens[i] != req->output_tokens[i]) {
          std::cout << "check token fail req " << *req << ", index = " << i << std::endl;
        }
      }
    }
  }

  const FakedBlockAllocatorGroup::Statistics& GetBlockManagerStat() { return block_allocator_group_->GetStatistics(); }

  const BlockManagerConfig& GetBlockManagerConfig() const { return block_manager_config_; }

  std::shared_ptr<FakedBlockAllocatorGroup> GetBlockAllocatorGroup() { return block_allocator_group_; }

 private:
  void SetRequestOutputTokenNum(std::shared_ptr<InferRequest>& req, int output_token_num) {
    KLLM_CHECK_WITH_INFO(req_output_num_map_.find(req->req_id) == req_output_num_map_.end(),
                         FormatStr("SetRequestOutputTokenNum: Req id %d is already set.", req->req_id));
    req_output_num_map_[req->req_id] = output_token_num;
  }

  void SetRequestGenerationSeeds(int req_id, const std::vector<std::pair<int, int>>& seeds) {
    KLLM_CHECK_WITH_INFO(req_generation_seeds_.find(req_id) == req_generation_seeds_.end(),
                         FormatStr("SetRequestGenerationSeed: Req id %d is already set.", req_id));
    KLLM_CHECK_WITH_INFO(seeds.size() > 0,
                         FormatStr("SetRequestGenerationSeed: Trying to set empty seed for req %d.", req_id));
    KLLM_CHECK_WITH_INFO(seeds[0].first == 0, "SetRequestGenerationSeed: First offset must be 0.");

    if (seeds.size() > 0) {
      int last_offset = 0;
      for (size_t i = 1; i < seeds.size(); i++) {
        KLLM_CHECK_WITH_INFO(seeds[i].first > last_offset, "SetRequestGenerationSeed: Wrong offset order.");
        last_offset = seeds[i].first;
      }
    }

    req_generation_seeds_[req_id] = seeds;
  }

  void GenerateAToken(std::vector<int>& input_tokens, int& output_token, int seed) {
    GenerateAFakeToken(input_tokens, output_token, seed);
  }

  int GetSeed(int offset, const std::vector<std::pair<int, int>>& seeds) { return GetFakeSeed(offset, seeds); }

 private:
  std::shared_ptr<FakedBlockAllocatorGroup> block_allocator_group_ = nullptr;
  std::unordered_map<int, int> req_output_num_map_;

  BlockManagerConfig block_manager_config_;

  // map for seeds used to generate input and output
  // <req_1, {<0, seed1>, <8, seed2>}> means req_1 will use seed1 to generate token from 0 to 8-1, seed2 for 8 to end.
  std::unordered_map<int, std::vector<std::pair<int, int>>> req_generation_seeds_;
  int tp_num_;
};

}  // namespace ksana_llm
